from State.buffer import VectorGCReplayBufferManager
from State.her_buffer import VectorGCHindsightReplayBufferManager
from State.collector import GCCollector


def initialize_data(config, policy, dynamics, single_env, env, test_env, norm, train=True):
    """Build the collectors (and, when training, the replay buffer) for a run.

    A test collector over ``test_env`` is always created. When ``train`` is
    false it is returned by itself. Otherwise a replay buffer manager and a
    training collector over ``env`` are created as well, and
    ``policy.hindsight_filter`` is registered on the training collector via
    ``set_her_trajectory_check``.

    Args:
        config: Experiment configuration; read through attribute access
            (``config.train``, ``config.data``, ``config.policy``, ...).
        policy: Policy handed to both collectors; must expose
            ``hindsight_filter`` when ``train`` is true.
        dynamics: Dynamics object handed to both collectors.
        single_env: Passed as the first positional argument of the buffer
            manager.
        env: Training environment(s); ``len(env)`` sets ``buffer_num``.
        test_env: Environment(s) used by the test collector.
        norm: Object providing ``normalize_goal`` and ``normalize_obs``.
        train: If false, only the test collector is built and returned.

    Returns:
        ``test_collector`` alone when ``train`` is false, otherwise the tuple
        ``(train_collector, test_collector, buffer)``.
    """
    # Goal conditioning is switched on by a positive per-goal step budget.
    goal_conditioned = config.train.n_steps_per_goal > 0

    test_collector = GCCollector(
        policy,
        dynamics,
        test_env,
        normalize_goal=norm.normalize_goal,
        normalize_obs=norm.normalize_obs,
        exploration_noise=False,
        name="test",
        root_dir=config.exp_path,
        save_gif_num=config.save.save_gif_num,
        num_factors=config.num_factors,
        render=config.env.render,
        goal_conditioned=goal_conditioned,
    )
    if not train:
        return test_collector

    her_cfg = config.data.her
    prio_cfg = config.data.prio
    timeout = config.policy.reward.timeout

    # Keyword arguments common to both buffer manager variants.
    buffer_kwargs = dict(
        buffer_num=len(env),
        use_her=her_cfg.use_her,
        her_use_count_select_goal=her_cfg.use_count_select_goal,
        horizon=timeout,
        future_k=her_cfg.future_k,
        alpha=prio_cfg.alpha,
        beta=prio_cfg.beta,
        weight_norm=prio_cfg.weight_norm,
        dynamics_per_priority_scale=prio_cfg.dynamics.priority_scale,
        dynamics_per_update_count_scale=prio_cfg.dynamics.update_count_scale,
        dynamics_per_change_count_scale=prio_cfg.dynamics.change_count_scale,
        policy_per_td_error_scale=prio_cfg.td_error_scale,
        policy_per_graph_count_scale=prio_cfg.graph_count_scale,
        count_threshold_for_valid_graph=config.data.count_threshold_for_valid_graph,
        decay_window=prio_cfg.decay_window,
        decay_rate=prio_cfg.decay_rate,
        max_prev_decay=prio_cfg.max_prev_decay,
    )

    separate_her = her_cfg.separate_her
    if separate_her:
        # Only the hindsight variant is given her_ratio.
        buffer_cls = VectorGCHindsightReplayBufferManager
        buffer_kwargs["her_ratio"] = her_cfg.her_ratio
    else:
        buffer_cls = VectorGCReplayBufferManager
    buffer_kwargs["use_prio"] = prio_cfg.prio

    buffer = buffer_cls(single_env, config.data.buffer_len, **buffer_kwargs)

    # collectors
    train_collector = GCCollector(
        policy,
        dynamics,
        env,
        normalize_goal=norm.normalize_goal,
        normalize_obs=norm.normalize_obs,
        buffer=buffer,
        timeout=timeout,
        exploration_noise=True,
        name="train",
        num_factors=config.num_factors,
        render=config.env.render,
        separate_her=separate_her,
        her_traj_length=her_cfg.her_traj_length,
        num_her_resamples=her_cfg.num_her_samples,
        use_lowest_post=her_cfg.use_lowest_post,
        goal_conditioned=goal_conditioned,
    )
    train_collector.set_her_trajectory_check(policy.hindsight_filter)
    return train_collector, test_collector, buffer
